import os

# Set at import time, before the project imports further down, so the variable
# is already in the environment when they load. CUDA_LAUNCH_BLOCKING=1 makes
# CUDA kernel launches synchronous, which is mainly useful for debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys

# Append the current working directory, with every '/interface' substring
# removed, to the import search path -- presumably so that the `evaluate`,
# `generic` and `agent` imports below resolve when the script is launched from
# an `interface/` sub-directory.
# NOTE(review): this depends on the launch directory rather than on this file's
# location, and only matches POSIX-style '/' separators -- confirm intended.
cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
print(sys.path)  # echo the effective import search path
from evaluate.evaluate_distrib_rl import evaluate_playing_style
from generic.model_util import get_distrib_q_model_save_path

from agent import SportsAgent
from generic.data_util import load_config, read_args, ICEHOCKEY_ACTIONS, divide_dataset_according2date


def test(args):
    """Load a saved distributional Q model and evaluate its playing style.

    The configuration returned by ``load_config(args)`` is overridden with the
    hard-coded evaluation settings below, a ``SportsAgent`` is built from it,
    the checkpoint ``saved_model_<load_dqn_episode>`` is loaded from the
    directory given by ``get_distrib_q_model_save_path`` and the agent is then
    passed to ``evaluate_playing_style``.

    Args:
        args: Run arguments, forwarded unchanged to ``load_config``.

    Returns:
        None.
    """
    # Hard-coded evaluation settings; they replace whatever the config holds.
    test_num_tau = 256
    test_num_supp = 256
    test_gamma = 1
    test_train_rate = 0.8
    max_trace_length = 3
    apply_dynamic_trace_length = False
    test_apply_rnn = True
    test_apply_resnet = True
    test_cut_at_goal = True
    load_dqn_episode = 'best'  # checkpoint suffix, e.g. 'best' or 10000
    date = 'Oct-31-2021'  # others: 'Nov-19-2021', 'Oct-29-2021', 'Nov-04-2021'
    debug_msg = ''

    config, debug_mode, log_file_path = load_config(args)

    model_config = config['general']['model']
    model_config['apply_rnn'] = test_apply_rnn
    model_config['apply_resnet'] = test_apply_resnet
    model_config['num_tau'] = test_num_tau
    model_config['num_supp'] = test_num_supp
    model_config['apply_dynamic_trace_length'] = apply_dynamic_trace_length
    model_config['max_trace_length'] = max_trace_length

    training_config = config['general']['training']
    training_config['gamma'] = test_gamma
    training_config['cut_at_goal'] = test_cut_at_goal
    training_config['train_rate'] = test_train_rate

    # The use_cuda flag is always switched off for this evaluation run.
    config['general']['use_cuda'] = False

    # Logging is optional: without a log path the agent receives None.
    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        agent = SportsAgent(config=config, log_file=log_file)
        model_save_mother_dir = get_distrib_q_model_save_path(
            agent=agent,
            date_label=date,
            debug_msg=debug_msg)
        dqn_load_from_path = (model_save_mother_dir
                              + '/saved_model_{0}'.format(load_dqn_episode))
        # The values returned by the loader are not needed for evaluation.
        agent.load_pretrained_model(load_from=dqn_load_from_path,
                                    load_optim=False,
                                    log_file=log_file)
        evaluate_playing_style(agent, debug_mode)
    finally:
        # Release the log file even if loading or evaluation fails.
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Script entry point: hand the result of read_args() straight to test().
    test(read_args())
